{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sad_construct import *\n",
    "from sad_read_cifar10 import get_cifar10\n",
    "(x_train,y_train),(x_test,y_test)=get_cifar10()\n",
    "x_train=x_train.reshape(-1,3*32*32)\n",
    "x_test=x_test.reshape(-1,3*32*32)\n",
    "cifar10_train_data = np.append(x_train, y_train, axis=1)\n",
    "cifar10_test_data = np.append(x_test, y_test, axis=1)\n",
    "cifar10_shape=(3,32,32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No file named cnn_cifar10.npy\n",
      "epoch 1:\n",
      " Loss:2.013303216 \tAccuracy: 0.3902 [0.462, 0.534, 0.157, 0.381, 0.125, 0.148, 0.536, 0.497, 0.579, 0.483]\n",
      "epoch 2:\n",
      " Loss:1.600409321 \tAccuracy: 0.4738 [0.469, 0.675, 0.16, 0.336, 0.378, 0.492, 0.552, 0.494, 0.748, 0.434]\n",
      "epoch 3:\n",
      " Loss:1.410191887 \tAccuracy: 0.5276 [0.607, 0.659, 0.37, 0.166, 0.406, 0.46, 0.739, 0.516, 0.767, 0.586]\n",
      "epoch 4:\n",
      " Loss:1.284979028 \tAccuracy: 0.5636 [0.534, 0.707, 0.576, 0.467, 0.45, 0.355, 0.677, 0.576, 0.745, 0.549]\n",
      "epoch 5:\n",
      " Loss:1.181551038 \tAccuracy: 0.5596 [0.453, 0.622, 0.289, 0.27, 0.65, 0.506, 0.715, 0.792, 0.617, 0.682]\n",
      "epoch 6:\n",
      " Loss:1.100592369 \tAccuracy: 0.6154 [0.6, 0.692, 0.483, 0.382, 0.615, 0.633, 0.611, 0.649, 0.709, 0.78]\n",
      "epoch 7:\n",
      " Loss:1.034263175 \tAccuracy: 0.6221 [0.711, 0.859, 0.446, 0.447, 0.521, 0.414, 0.858, 0.652, 0.738, 0.575]\n",
      "epoch 8:\n",
      " Loss:0.978070840 \tAccuracy: 0.6298 [0.712, 0.773, 0.692, 0.277, 0.572, 0.55, 0.646, 0.649, 0.798, 0.629]\n",
      "epoch 9:\n",
      " Loss:0.933537398 \tAccuracy: 0.6417 [0.641, 0.744, 0.598, 0.368, 0.392, 0.543, 0.826, 0.746, 0.876, 0.683]\n",
      "epoch 10:\n",
      " Loss:0.888729485 \tAccuracy: 0.6531 [0.696, 0.774, 0.593, 0.43, 0.441, 0.695, 0.671, 0.719, 0.797, 0.715]\n",
      "epoch 11:\n",
      " Loss:0.851064132 \tAccuracy: 0.6645 [0.795, 0.685, 0.499, 0.528, 0.541, 0.658, 0.68, 0.746, 0.734, 0.779]\n",
      "epoch 12:\n",
      " Loss:0.819274257 \tAccuracy: 0.6693 [0.724, 0.822, 0.41, 0.618, 0.608, 0.572, 0.739, 0.751, 0.795, 0.654]\n",
      "epoch 13:\n",
      " Loss:0.786782960 \tAccuracy: 0.6795 [0.725, 0.72, 0.616, 0.419, 0.531, 0.589, 0.842, 0.748, 0.761, 0.844]\n",
      "epoch 14:\n",
      " Loss:0.757388818 \tAccuracy: 0.6757 [0.741, 0.795, 0.576, 0.309, 0.659, 0.647, 0.829, 0.688, 0.802, 0.711]\n",
      "epoch 15:\n",
      " Loss:0.729849132 \tAccuracy: 0.6663 [0.732, 0.711, 0.512, 0.553, 0.75, 0.53, 0.611, 0.642, 0.874, 0.748]\n",
      "epoch 16:\n",
      " Loss:0.707035686 \tAccuracy: 0.6977 [0.731, 0.752, 0.619, 0.512, 0.667, 0.559, 0.806, 0.727, 0.789, 0.815]\n",
      "epoch 17:\n",
      " Loss:0.684247732 \tAccuracy: 0.6846 [0.731, 0.866, 0.677, 0.404, 0.679, 0.527, 0.731, 0.713, 0.781, 0.737]\n",
      "epoch 18:\n",
      " Loss:0.658915233 \tAccuracy: 0.6928 [0.707, 0.732, 0.53, 0.458, 0.648, 0.56, 0.802, 0.798, 0.836, 0.857]\n",
      "epoch 19:\n",
      " Loss:0.642962313 \tAccuracy: 0.6779 [0.738, 0.714, 0.642, 0.267, 0.681, 0.671, 0.688, 0.723, 0.869, 0.786]\n",
      "epoch 20:\n",
      " loss:0.330649751 \tprocess: 43.58%\t   "
     ]
    }
   ],
   "source": [
    "train_and_test(\n",
    "    mynetwork=cnn_cifar10,\n",
    "    train_data=cifar10_train_data,\n",
    "    test_data=cifar10_test_data,\n",
    "    data_shape=cifar10_shape,\n",
    "    onehotsize=10,\n",
    "    batch_size=16,\n",
    "    epoch_num=50,\n",
    "    half_learning_rate_time=5,\n",
    "    pth=\"cnn_cifar10.npy\"\n",
    ")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "face",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
